{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 在上一讲中，我们展示了几种实现简易平均加权的方式\n",
    "    - 在这里，做一些一些准备工作，从而实现后续搭建自注意力模块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BLM(nn.Module):\n",
    "    def __init__(self,vocab_size):\n",
    "        super().__init__()\n",
    "        self.token_embedding_table = nn.Embedding(vocab_size,vocab_size)\n",
    "    \n",
    "    def forward(self,idx,targets = None):\n",
    "\n",
    "        logits = self.token_embedding_table(idx) # (B,T)  -> (B,T,C)  # 这里我们通过Embedding操作直接得到预测分数\n",
    "        # 这里的预测分数过程与二分类或者多分类的分数是大致相同的\n",
    "\n",
    "        \n",
    "        if targets is None:\n",
    "            loss = None\n",
    "        else:   \n",
    "            B, T, C = logits.shape\n",
    "            logits = logits.view(B*T, C)\n",
    "            targets = targets.view(B*T)    # 这里我们调整一下形状，以符合torch的交叉熵损失函数对于输入的变量的要求\n",
    "            loss = F.cross_entropy(logits, targets)\n",
    "\n",
    "        return logits, loss\n",
    "\n",
    "    def generate(self, idx, max_new_tokens):\n",
    "        '''\n",
    "        idx 是现在的输入的(B, T)序列 ，这是之前我们提取的batch的下标\n",
    "        max_new_tokens 是产生的最大的tokens数量\n",
    "        '''\n",
    "\n",
    "        for _ in range(max_new_tokens):\n",
    "            \n",
    "            # 得到预测的结果\n",
    "            logits,_ = self(idx) # _ 表示省略，用于不获取相对应的函数返回值\n",
    "            \n",
    "            # 只关注最后一个的预测  (B,T,C)\n",
    "            logits = logits[:, -1, :] # becomes (B, C)\n",
    "            # 对概率值应用softmax\n",
    "            probs = F.softmax(logits, dim=-1) # (B, C)\n",
    "            # nn.argmax\n",
    "            # 对input的每一行做n_samples次取值，输出的张量是每一次取值时input张量对应行的下标，也即找到概率值输出最大的下标，也对应着最大的编码\n",
    "            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
    "            # 将新产生的编码加入到之前的编码中，形成新的编码\n",
    "            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
    "\n",
    "        return idx "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
